package com.lxm.framework.mybatisplus.interceptor;

import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.core.MybatisParameterHandler;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.core.metadata.OrderItem;
import com.baomidou.mybatisplus.core.parser.ISqlParser;
import com.baomidou.mybatisplus.core.parser.SqlInfo;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectFactory;
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectModel;
import com.baomidou.mybatisplus.extension.plugins.pagination.dialects.IDialect;
import com.baomidou.mybatisplus.extension.plugins.pagination.optimize.JsqlParserCountOptimize;
import com.lxm.framework.common.cache.caffeine.CaffeineCache;
import com.lxm.framework.mybatisplus.parser.CustomizedSqlParser;
import com.lxm.framework.mybatisplus.util.InterceptUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.RowBounds;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.*;
import java.util.stream.Collectors;

/**
 * @Author: Lys
 * @Date 2022/2/25
 * @Describe
 **/
@Slf4j
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class StatementInterceptor extends AbstractInterceptor implements Interceptor {

    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    @Override
    public void setProperties(Properties properties) {
    }

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        return invoke(invocation);
    }

    @Override
    Object invoke(Invocation invocation) throws Exception {
        StatementHandler statementHandler = PluginUtils.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
        String originalSql = boundSql.getSql();
        Connection connection = (Connection) invocation.getArgs()[0];
        String caffeineKey = InterceptUtils.genKey(mappedStatement.getId(), originalSql);
        boolean hasCache = CaffeineCache.hasKey(caffeineKey);
        if (hasCache) {
            originalSql = CaffeineCache.read(caffeineKey, String.class);
        } else {
            boolean successParsed = false;
            try {
                boolean escape = InterceptUtils.escape(mappedStatement.getId());
                var parser = new CustomizedSqlParser(escape, connection, originalSql, null);
                var sqlInfo = parser.parser(originalSql);
                if (Objects.nonNull(sqlInfo) && StringUtils.isNotBlank(sqlInfo.getSql())) {
                    originalSql = sqlInfo.getSql();
                    successParsed = true;
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
            CaffeineCache.write(caffeineKey, originalSql);
            log.debug("lxm statement-interceptor , sql parse result : {} , sql : {}", successParsed, originalSql);
        }
        var sqlCommandType = mappedStatement.getSqlCommandType();
        if (SqlCommandType.SELECT.equals(sqlCommandType)) {
            originalSql = selectWithPage(metaObject, connection, mappedStatement, boundSql, originalSql);
            if (StringUtils.isBlank(originalSql)) {
                return null;
            }
        }
        return invocation.proceed();
    }

    private String selectWithPage(MetaObject metaObject, Connection connection, MappedStatement mappedStatement, BoundSql boundSql, String originalSql) {
        Object paramObj = boundSql.getParameterObject();
        IPage page = null;
        if (paramObj instanceof IPage) {
            page = (IPage) paramObj;
        } else if (page instanceof Map) {
            for (Object arg : ((Map) paramObj).values()) {
                if (arg instanceof IPage) {
                    page = (IPage) arg;
                    break;
                }
            }
        }
        //不需要分页的场合，如果 size 小于 0 返回结果集
        if (null == page || page.getSize() < 0) {
            return originalSql;
        }
        boolean orderBy = true;
        // 计算page总计
        if (page.getTotal() == 0) {
            SqlInfo sqlInfo = getSqlInfo(page.optimizeCountSql(), originalSql, metaObject);
            queryTotal(sqlInfo.getSql(), mappedStatement, boundSql, page, connection);
            if (page.getTotal() <= 0) {
                return null;
            }
            orderBy = sqlInfo.isOrderBy();
        }
        String buildSql = concatOrderBy(originalSql, page, orderBy);
        IDialect dialect = DialectFactory.getDialect(DbType.MYSQL);
        DialectModel model = dialect.buildPaginationSql(buildSql, page.offset(), page.getSize());
        Configuration configuration = mappedStatement.getConfiguration();
        List<ParameterMapping> mappings = new ArrayList<>(boundSql.getParameterMappings());
        Map<String, Object> additionalParameters = (Map) metaObject.getValue("delegate.boundSql.additionalParameters");
        model.consumers(mappings, configuration, additionalParameters);
        originalSql = model.getDialectSql();
        metaObject.setValue("delegate.boundSql.parameterMappings", mappings);
        metaObject.setValue("delegate.rowBounds.offset", RowBounds.NO_ROW_OFFSET);
        metaObject.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT);
        return originalSql;
    }

    private String concatOrderBy(String originalSql, IPage page, boolean orderBy) {
        List<OrderItem> orders = page.orders();
        if (!orderBy || CollectionUtils.isEmpty(orders)) {
            return originalSql;
        }
        StringBuilder buildSql = new StringBuilder(originalSql);
        String ascStr = orders.stream().filter(OrderItem::isAsc).map(OrderItem::getColumn).collect(Collectors.joining(","));
        String descStr = orders.stream().filter(o -> !o.isAsc()).map(OrderItem::getColumn).collect(Collectors.joining(","));
        if (org.apache.commons.lang3.StringUtils.isNotBlank(ascStr) || org.apache.commons.lang3.StringUtils.isNotBlank(descStr)) {
            if (org.apache.commons.lang3.StringUtils.isNotBlank(ascStr)) {
                if (org.apache.commons.lang3.StringUtils.isNotBlank(descStr)) {
                    ascStr = ascStr + " ASC" + ", ";
                } else {
                    ascStr = ascStr + " ASC";
                }
            }
            if (org.apache.commons.lang3.StringUtils.isNotBlank(descStr)) {
                descStr = descStr + " DESC";
            }
            buildSql.append(" ORDER BY ").append(ascStr).append(descStr);
        }
        return buildSql.toString();
    }

    private SqlInfo getSqlInfo(boolean optimizeCountSql, String originalSql, MetaObject metaObject) {
        return optimizeCountSql ? ((ISqlParser) Optional.empty().orElseGet(JsqlParserCountOptimize::new)).parser(metaObject, originalSql) : SqlInfo.newInstance().setSql(String.format("SELECT COUNT(*) FROM (%s) TOTAL", originalSql));
    }

    private void queryTotal(String sql, MappedStatement mappedStatement, BoundSql boundSql, IPage page, Connection connection) {
        try (PreparedStatement statement = connection.prepareStatement(sql)) {
            MybatisParameterHandler parameterHandler = new MybatisParameterHandler(mappedStatement, boundSql.getParameterObject(), boundSql);
            parameterHandler.setParameters(statement);
            long total = 0;
            try (ResultSet resultSet = statement.executeQuery()) {
                if (resultSet.next()) {
                    total = resultSet.getLong(1);
                }
            }
            page.setTotal(total);
        } catch (Exception e) {
            throw ExceptionUtils.mpe("Error: Method queryTotal execution error.", e);
        }
    }
}
